{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text Encoding with DJL Spark Support\n",
    "\n",
    "In this example, we use a Jupyter notebook to run text encoding with the DJL Spark extension in Scala. To run this Scala kernel, you need to install [Almond](https://almond.sh/), a Scala kernel for Jupyter Notebook. Almond provides extensive functionality for Scala and Spark applications.\n",
    "\n",
    "[Almond installation instructions](https://almond.sh/docs/quick-start-install) (Note: only Scala 2.12 has been tested)\n",
    "\n",
    "After that, you can start with DJL's Scala notebook.\n",
    "\n",
    "\n",
    "## Import dependencies\n",
    "\n",
    "First, let's import the dependencies we need."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import $ivy.`org.apache.spark::spark-sql:3.3.2`\n",
    "import $ivy.`ai.djl:api:0.27.0`\n",
    "import $ivy.`ai.djl.spark:spark_2.12:0.27.0`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can import the packages we need. The last two lines turn off logging for Spark (the \"org\" loggers) and DJL (the \"ai\" loggers) so that log messages do not clutter the cell outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.apache.spark.sql.NotebookSparkSession\n",
    "import ai.djl.spark.task.text.{TextDecoder, TextEncoder}\n",
    "import org.apache.spark.sql.SparkSession\n",
    "\n",
    "import org.apache.log4j.{Level, Logger}\n",
    "Logger.getLogger(\"org\").setLevel(Level.OFF) // silence Spark and other \"org.*\" loggers\n",
    "Logger.getLogger(\"ai\").setLevel(Level.OFF) // silence DJL (\"ai.*\") loggers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start Spark application\n",
    "\n",
    "We can create a `NotebookSparkSession` through the Almond Spark plugin. Internally, it makes all the necessary jars available to each worker node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Create Spark session\n",
    "val spark = {\n",
    "  NotebookSparkSession.builder()\n",
    "    .master(\"local[*]\")\n",
    "    .getOrCreate()\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create a DataFrame with text values using the Spark library:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val df = spark.createDataFrame(Seq(\n",
    "  (1, \"Hello, y'all! How are you?\"),\n",
    "  (2, \"Hello to you too!\"),\n",
    "  (3, \"I'm fine, thank you!\")\n",
    ")).toDF(\"id\", \"text\")\n",
    "df.show(truncate=false)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can run encoding on the text. All we need to do is create a `TextEncoder`, select the \"bert-base-cased\" tokenizer, and run encoding with DJL. The output is a StructType column named \"encoded\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val encoder = new TextEncoder()\n",
    "  .setInputCol(\"text\")\n",
    "  .setOutputCol(\"encoded\")\n",
    "  .setHfModelId(\"bert-base-cased\")\n",
    "var encDf = encoder.encode(df)\n",
    "encDf.printSchema()"
   ]
  },
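  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To inspect the encoded values rather than only the schema, a minimal sketch along these lines may help. It assumes the cells above have already run, so `encDf` exists, and that the \"encoded\" struct exposes an `ids` field, which the decoding step in this notebook also relies on:\n",
    "\n",
    "```scala\n",
    "// Sketch: show the token IDs produced for each input row.\n",
    "// Assumes `encDf` from the previous cell and an `ids` field inside \"encoded\".\n",
    "encDf.select(\"id\", \"encoded.ids\").show(truncate = false)\n",
    "```\n",
    "\n",
    "Other fields of the struct can be viewed the same way, using the names that `printSchema()` reports."
   ]
  },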
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can decode the encoded text. The `select` call first flattens the \"encoded\" struct so that its `ids` field becomes a top-level column. After that, all we need to do is create a `TextDecoder`, select the \"bert-base-cased\" tokenizer, and run decoding with DJL. The output is a StringType column named \"decoded\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encDf = encDf.select(\"id\", \"text\", \"encoded.*\")\n",
    "val decoder = new TextDecoder()\n",
    "  .setInputCol(\"ids\")\n",
    "  .setOutputCol(\"decoded\")\n",
    "  .setHfModelId(\"bert-base-cased\")\n",
    "val decDf = decoder.decode(encDf)\n",
    "decDf.printSchema()\n",
    "decDf.select(\"id\", \"text\", \"ids\", \"decoded\").show(truncate=false)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Scala",
   "language": "scala",
   "name": "scala"
  },
  "language_info": {
   "codemirror_mode": "text/x-scala",
   "file_extension": ".sc",
   "mimetype": "text/x-scala",
   "name": "scala",
   "nbconvert_exporter": "script",
   "version": "2.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
